package com.alibaba.simpleimage.analyze.search.tree;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import com.alibaba.simpleimage.analyze.search.cluster.ClusterBuilder;
import com.alibaba.simpleimage.analyze.search.cluster.Clusterable;
import com.alibaba.simpleimage.analyze.search.cluster.impl.Cluster;
import com.alibaba.simpleimage.analyze.search.util.ClusterUtils;
import com.alibaba.simpleimage.analyze.search.util.TreeUtils;

/**
 * A node of a hierarchical k-means ("vocabulary") tree.
 * <p>
 * The constructor recursively splits its items with a {@link ClusterBuilder}.
 * Splitting stops when the maximum height is reached or too few items remain.
 * Each leaf gets an id (its "word") taken from the shared counter
 * {@code KMeansTree.idCount}; inner nodes keep id -1. After construction,
 * {@link #getValueId(Clusterable)} routes an item down the tree and counts how
 * many items passed through each node.
 */
public class KMeansTreeNode implements Clusterable, Serializable {
	private static final long serialVersionUID = 1L;

	/** Child nodes; empty for leaves. Never null after construction. */
	private List<KMeansTreeNode> subNodes;

	/** True when this node was made a leaf during construction. */
	private boolean isLeafNode = false;
	/** Location of this node as passed to the constructor (children receive their cluster mean). */
	private float[] center;
	/** Depth of this node: children are built with their parent's height + 1. */
	private int height = 0;

	/** Number of items given to this node at construction time (its "weight"). */
	private int numSubItems;
	/** Number of getValueId calls that passed through this node since the last reset(). */
	private int currentItems;
	/** Unique leaf id in the tree (the "word"); -1 for inner nodes. */
	private int id = -1;

	/**
	 * Builds this node and, recursively, its whole subtree.
	 *
	 * @param center         location of this node
	 * @param items          items assigned to this node
	 * @param branchFactor   number of clusters requested from {@code clusterBuilder} per split
	 * @param maxHeight      height at which nodes are forced to become leaves
	 * @param height         depth of this node
	 * @param clusterBuilder clustering strategy used to split {@code items}
	 */
	public KMeansTreeNode(float[] center, List<Clusterable> items,
			int branchFactor, int maxHeight, int height, ClusterBuilder clusterBuilder) {
		// TODO: leaf ids come from the shared static counter KMeansTree.idCount, so
		// they depend on global construction order (presumably not safe for
		// concurrent tree construction).
		// NOTE(review): a mean of Euclidean distances is normally non-negative, so the
		// getMeanDist(...) < 0 test looks like it never fires; presumably a placeholder
		// for a distance threshold -- confirm before removing.
		if (height == maxHeight || items.size() < branchFactor
				|| getMeanDist(items, center) < 0) {
			subNodes = new ArrayList<KMeansTreeNode>(0);
			makeLeaf();
		} else {
			Clusterable[] clusters = clusterBuilder.collect(items, branchFactor);
			subNodes = new ArrayList<KMeansTreeNode>(branchFactor);
			for (Clusterable clusterable : clusters) {
				// Only non-empty Cluster instances carry items to recurse into.
				if (!(clusterable instanceof Cluster)) {
					continue;
				}
				Cluster cluster = (Cluster) clusterable;
				if (cluster.getItems().isEmpty()) {
					continue;
				}
				subNodes.add(new KMeansTreeNode(cluster.getClusterMean(), cluster.getItems(),
						branchFactor, maxHeight, height + 1, clusterBuilder));
			}
			// Clustering produced no usable child: this node is effectively a leaf.
			// Give it an id so getValueId never reports the placeholder -1 for it.
			if (subNodes.isEmpty()) {
				makeLeaf();
			}
		}
		this.height = height;
		this.center = center;
		this.numSubItems = items.size();
	}

	/** Marks this node as a leaf and assigns it the next id from KMeansTree.idCount. */
	private void makeLeaf() {
		isLeafNode = true;
		id = KMeansTree.idCount++;
	}

	/**
	 * Returns the arithmetic mean of the Euclidean distances from each item's
	 * location to {@code center}, or 0 when {@code items} is empty (avoiding 0/0 = NaN).
	 */
	private float getMeanDist(List<Clusterable> items, float[] center) {
		if (items.isEmpty()) {
			return 0f;
		}
		float sum = 0;
		for (Clusterable item : items) {
			sum += ClusterUtils.getEuclideanDistance(item.getLocation(), center);
		}
		return sum / items.size();
	}

	/** Returns whether this node is a leaf. */
	public boolean isLeafNode() {
		return isLeafNode;
	}

	/** Returns the live (mutable) list of child nodes; empty for leaves. */
	public List<KMeansTreeNode> getSubNodes() {
		return subNodes;
	}

	/** Returns this node's location (the center it was constructed with). */
	public float[] getLocation() {
		return center;
	}

	/** Returns the number of items this node was constructed with. */
	public int getNumSubItems() {
		return numSubItems;
	}

	/** Returns the depth of this node. */
	public int getHeight() {
		return height;
	}

	/** Returns this node's id: a unique leaf id, or -1 for inner nodes. */
	public int getId() {
		return id;
	}

	/**
	 * Routes {@code c} down the tree and returns the id of the node where the
	 * descent stops.
	 * <p>
	 * Every node visited, including this one, increments its current item count.
	 * At each level the child is chosen by {@link TreeUtils#findNearestNodeIndex};
	 * descent stops when that helper returns a negative index (NOTE(review):
	 * presumably only when the node has no children).
	 *
	 * @param c the item to assign to a word
	 * @return the id of the node where descent stopped
	 */
	public int getValueId(Clusterable c) {
		currentItems++;
		int index = TreeUtils.findNearestNodeIndex(subNodes, c);
		if (index >= 0) {
			return subNodes.get(index).getValueId(c);
		}
		return id;
	}

	/** Returns how many getValueId calls passed through this node since the last reset. */
	public int getCurrentItemCount() {
		return currentItems;
	}

	/** Clears the current item count on this node and on all of its descendants. */
	public void reset() {
		currentItems = 0;
		for (KMeansTreeNode node : subNodes) {
			node.reset();
		}
	}

	@Override
	public String toString() {
		return "KMeansTreeNode [isLeafNode=" + isLeafNode + ", center="
				+ Arrays.toString(center) + ", height=" + height
				+ ", numSubItems=" + numSubItems + ", currentItems="
				+ currentItems + ", id=" + id + "]";
	}

}